import time
import uuid
from typing import Dict, Any

from app.base.logger import setup_logger
from app.service.knowledge_graph_service import knowledge_graph_service
from app.service.documentation_service import documentation_service
from fastmcp import FastMCP, Context

# 使用新的日志配置
logger = setup_logger("mcp_service")

# 存储会话信息
sessions = {}

class MCPService:
    """
    FastMCP服务实现，用于对话和知识图谱查询
    """
    def __init__(self):
        self.mcp = FastMCP(name="知识图谱查询服务")
        self._register_tools()
    
    def _register_tools(self):
        """注册所有MCP工具"""
        
        @self.mcp.tool()
        def start_session(ctx: Context = None) -> Dict[str, Any]:
            """
            开启一个新的对话会话
            
            Returns:
                Dict: 包含会话ID的字典
            """
            try:
                logger.info("开启新对话会话")
                
                # 确保知识图谱服务已初始化
                if not knowledge_graph_service.initialized:
                    success = knowledge_graph_service.initialize()
                    if not success:
                        error_msg = knowledge_graph_service.get_error_message()
                        logger.error(f"初始化知识图谱服务失败: {error_msg}")
                        return {"error": f"无法开启会话: {error_msg}"}
                
                # 生成会话ID
                session_id = str(uuid.uuid4())
                
                # 获取当前时间戳
                current_timestamp = int(time.time())
                
                # 存储会话信息
                sessions[session_id] = {
                    "id": session_id,
                    "created_at": current_timestamp,
                    "history": [],
                    "query_map": {}  # 查询与结果的映射
                }
                
                logger.info(f"新会话已创建: ID={session_id}")
                return {"session_id": session_id}
                
            except Exception as e:
                logger.exception(f"开启会话失败: {str(e)}")
                return {"error": f"开启会话失败: {str(e)}"}
        
        @self.mcp.tool()
        def echo(text: str, ctx: Context = None) -> Dict[str, Any]:
            """
            简单的回声测试工具
            
            Args:
                text: 要回显的文本
                ctx: FastMCP上下文对象
            
            Returns:
                Dict: 包含回声的字典
            """
            return {"echo": text}
        
        @self.mcp.tool()
        def query_knowledge_graph(
            session_id: str,
            original_user_query_text: str,
            query_text: str,
            repo_id: str,
            top_k: int = 3,
            ctx: Context = None,
            mcp_resort: bool = False
        ) -> Dict[str, Any]:
            """
            查询知识图谱
            查到知识图谱结果后，你必须让用户确认使用哪个节点，才能继续往下走
            Args:
                session_id: 会话ID
                original_user_query_text: 用户原始查询文本
                query_text: 查询文本(最好是精确到你想要查询的类\方法的全限定名)
                repo_id: 仓库的远程地址例如git@gitee.com:whrdckf3/maling-query.git，从现有上下文中获取，如果没有则执行git命令获取，命令获取失败则传空字符串
                top_k: 每种节点类型返回的最大结果数,最小为3
                ctx: FastMCP上下文对象
                mcp_resort: 是否支持MCP resort
                
            Returns:
                Dict: 包含查询结果的字典
            """
            try:
                logger.info(f"查询知识图谱: 会话ID={session_id}, 仓库ID={repo_id}, 查询={query_text[:50]}...")
                
                # 验证会话
                if session_id not in sessions:
                    logger.warning(f"无效的会话ID: {session_id}")
                    return {"error": "无效的会话ID，请先开启会话"}
                
                # 检查repo_id是否为空
                if not repo_id:
                    logger.warning("仓库ID为空")
                    return {"error": "仓库ID不能为空，请提供有效的仓库ID"}
                
                # 确保知识图谱服务已初始化
                if not knowledge_graph_service.initialized:
                    success = knowledge_graph_service.initialize()
                    if not success:
                        error_msg = knowledge_graph_service.get_error_message()
                        logger.error(f"初始化知识图谱服务失败: {error_msg}")
                        return {"error": f"无法执行查询: {error_msg}"}
                
                # 将请求添加到会话历史
                session = sessions[session_id]
                
                # 获取当前时间戳
                current_timestamp = int(time.time())
                
                query_entry = {
                    "type": "query",
                    "timestamp": current_timestamp,
                    "original_user_query": original_user_query_text,
                    "query": query_text,
                    "repo_id": repo_id
                }
                session["history"].append(query_entry)
                
                # 执行搜索
                results = knowledge_graph_service.search_knowledge_graph(
                    query_text=query_text,
                    repo_id=repo_id,
                    top_k=top_k,
                    related_limit=3  # 限制关联节点数量
                )
                
                # 检查是否有错误
                if "error" in results:
                    logger.error(f"查询知识图谱失败: {results['error']}")
                    return {"error": results["error"]}
                
                # 将结果添加到会话历史
                current_timestamp = int(time.time())
                result_entry = {
                    "type": "result",
                    "timestamp": current_timestamp,
                    "query": query_text,
                    "results": results
                }
                session["history"].append(result_entry)
                
                # 更新查询映射
                query_key = f"{query_text}_{current_timestamp}"
                session["query_map"][query_key] = results
                
                # 格式化返回结果（适配层级化结构）
                def format_node(node_data):
                    """格式化单个节点数据"""
                    # 基础字段
                    formatted_node = {
                        "id": node_data.get("id", ""),
                        "node_type": node_data.get("node_type", ""),
                        "full_name": node_data.get("full_name", ""),
                        "is_direct_match": node_data.get("is_direct_match", False)
                    }

                    # 添加相似度分数（如果存在）
                    if "rank" in node_data:
                        formatted_node["rank"] = node_data["rank"]

                    # 添加关系信息（如果存在）
                    if "relation" in node_data:
                        formatted_node["relation"] = node_data["relation"]

                    # 根据节点类型添加特定字段
                    if node_data.get("node_type") == "class_summary":
                        # 类摘要节点，添加class_info字段
                        formatted_node["class_info"] = node_data.get("class_info", {})
                    else:
                        # 其他节点类型，添加summary字段
                        formatted_node["summary"] = node_data.get("summary", "")

                    return formatted_node

                formatted_results = []
                total_related_count = 0

                for primary_node in results.get("results", []):
                    # 格式化第一层节点（向量匹配的节点）
                    formatted_primary = format_node(primary_node)

                    # 格式化关联节点
                    formatted_related = []
                    for related_node in primary_node.get("related_nodes", []):
                        formatted_related_node = format_node(related_node)
                        formatted_related.append(formatted_related_node)
                        total_related_count += 1

                    # 添加关联节点到第一层节点
                    formatted_primary["related_nodes"] = formatted_related
                    formatted_results.append(formatted_primary)

                logger.info(f"查询知识图谱完成: 找到 {len(formatted_results)} 个第一层节点，"
                           f"总计 {total_related_count} 个关联节点")
                return {
                    "query_text": query_text,
                    "results": formatted_results
                }
                
            except Exception as e:
                logger.exception(f"查询知识图谱失败: {str(e)}")
                return {"error": f"查询知识图谱失败: {str(e)}"}
        
        @self.mcp.tool()
        def get_node_content(
            session_id: str,
            node_id: str,
            repo_id: str,
            ctx: Context = None
        ) -> Dict[str, Any]:
            """
            获取节点内容
            
            Args:
                session_id: 会话ID
                node_id: 节点ID
                repo_id: 仓库的远程地址例如git@gitee.com:whrdckf3/maling-query.git，从现有上下文中获取，如果没有则执行git命令获取，命令获取失败则传空字符串
                ctx: FastMCP上下文对象
            
            Returns:
                Dict: 包含节点内容的字典
            """
            try:
                logger.info(f"获取节点内容: 会话ID={session_id}, 节点ID={node_id}, 仓库ID={repo_id}")
                
                # 验证会话
                if session_id not in sessions:
                    logger.warning(f"无效的会话ID: {session_id}")
                    return {"error": "无效的会话ID，请先开启会话"}
                
                # 检查repo_id是否为空
                if not repo_id:
                    logger.warning("仓库ID为空")
                    return {"error": "仓库ID不能为空，请提供有效的仓库ID"}
                
                # 确保知识图谱服务已初始化
                if not knowledge_graph_service.initialized:
                    success = knowledge_graph_service.initialize()
                    if not success:
                        error_msg = knowledge_graph_service.get_error_message()
                        logger.error(f"初始化知识图谱服务失败: {error_msg}")
                        return {"error": f"无法获取节点内容: {error_msg}"}
                
                # 将请求添加到会话历史
                session = sessions[session_id]
                current_timestamp = int(time.time())
                request_entry = {
                    "type": "content_request",
                    "timestamp": current_timestamp,
                    "node_id": node_id,
                    "repo_id": repo_id  # 添加repo_id到历史记录
                }
                session["history"].append(request_entry)
                
                # 获取节点内容
                node_data = knowledge_graph_service.get_node_content(node_id, repo_id)
                
                # 检查是否有错误
                if "error" in node_data:
                    logger.error(f"获取节点内容失败: {node_data['error']}")
                    return {"error": node_data["error"]}
                
                # 将结果添加到会话历史
                current_timestamp = int(time.time())
                content_entry = {
                    "type": "content_result",
                    "timestamp": current_timestamp,
                    "node_id": node_id,
                    "data": node_data
                }
                session["history"].append(content_entry)
                
                # 格式化返回结果
                formatted_result = {
                    "id": node_data.get("id", ""),
                    "node_type": node_data.get("node_type", ""),
                    "properties": node_data.get("properties", {})
                }
                
                logger.info(f"获取节点内容完成: 节点ID={node_id}, 类型={node_data.get('node_type', '')}")
                return formatted_result
                
            except Exception as e:
                logger.exception(f"获取节点内容失败: {str(e)}")
                return {"error": f"获取节点内容失败: {str(e)}"}
        
        @self.mcp.tool()
        def close_session(
            session_id: str,
            ctx: Context = None
        ) -> Dict[str, Any]:
            """
            关闭会话
            
            Args:
                session_id: 会话ID
                ctx: FastMCP上下文对象
            
            Returns:
                Dict: 包含操作状态的字典
            """
            try:
                logger.info(f"关闭会话: ID={session_id}")
                
                # 验证会话
                if session_id not in sessions:
                    logger.warning(f"无效的会话ID: {session_id}")
                    return {"error": "无效的会话ID"}
                
                # 移除会话
                session = sessions.pop(session_id, None)
                
                # 返回操作状态
                logger.info(f"会话已关闭: ID={session_id}")
                return {"status": "success", "message": "会话已关闭"}
                
            except Exception as e:
                logger.exception(f"关闭会话失败: {str(e)}")
                return {"error": f"关闭会话失败: {str(e)}"}

        @self.mcp.tool()
        def query_aggregated_documentation(
            session_id: str,
            repo_id: str,
            branch_name: str,
            original_user_query_text: str,
            query_text: str,
            topk: int = 5,
            ctx: Context = None
        ) -> Dict[str, Any]:
            """
            查询聚合说明书接口

            Args:
                session_id: 会话ID（必需，需要先调用start_session获取）
                repo_id: 仓库ID
                branch_name: 分支名称
                original_user_query_text: 原始用户查询文本
                query_text: 处理后的查询文本
                topk: 返回结果数量，默认5
                ctx: FastMCP上下文对象

            Returns:
                Dict: 包含说明书内容的字典
            """
            try:
                logger.info(f"查询聚合说明书: session_id={session_id}, repo_id={repo_id}, "
                           f"branch_name={branch_name}, query={query_text[:50]}..., topk={topk}")

                # 验证会话
                if session_id not in sessions:
                    logger.warning(f"无效的会话ID: {session_id}")
                    return {"error": "无效的会话ID，请先开启会话"}

                # 获取session信息
                session = sessions[session_id]
                logger.info(f"使用会话: {session_id}, 创建时间: {session.get('created_at', 'Unknown')}")

                # 确保文档服务已初始化
                if not documentation_service.initialized:
                    success = documentation_service.initialize()
                    if not success:
                        error_msg = documentation_service.get_error_message()
                        logger.error(f"初始化文档服务失败: {error_msg}")
                        return {"error": f"无法查询文档: {error_msg}"}

                # 搜索聚合说明书
                search_results = documentation_service.search_aggregated_documents(
                    repo_id=repo_id,
                    branch_name=branch_name,
                    query_text=query_text,
                    topk=topk
                )

                if not search_results:
                    logger.info("未找到匹配的聚合说明书")
                    return {
                        "session_id": session_id,
                        "original_query": original_user_query_text,
                        "processed_query": query_text,
                        "results": [],
                        "message": "未找到匹配的聚合说明书"
                    }

                # 批量获取所有匹配文档的内容
                document_ids = [result.document_id for result in search_results]
                doc_contents = documentation_service.get_documentation_content_batch(document_ids)

                if doc_contents:
                    logger.info(f"成功批量获取聚合说明书内容: 获取到 {len(doc_contents)} 个文档")

                    # 创建文档ID到内容的映射
                    content_map = {str(doc.get('id', '')): doc for doc in doc_contents}

                    # 构建结果，包含完整内容的文档
                    enriched_results = []
                    for result in search_results:
                        doc_content = content_map.get(result.document_id, {})
                        enriched_results.append({
                            "document_id": result.document_id,
                            "score": result.score,
                            "document_type": result.document_type,
                            "content": doc_content.get('content', ''),
                            "content_preview": result.content[:200] + "..." if len(result.content) > 200 else result.content
                        })

                    return {
                        "session_id": session_id,
                        "original_query": original_user_query_text,
                        "processed_query": query_text,
                        "total_results": len(enriched_results),
                        "results": enriched_results
                    }
                else:
                    logger.error(f"无法批量获取文档内容: document_ids={document_ids}")
                    return {
                        "session_id": session_id,
                        "original_query": original_user_query_text,
                        "processed_query": query_text,
                        "error": "找到匹配文档但无法获取内容",
                        "document_ids": document_ids
                    }

            except Exception as e:
                logger.exception(f"查询聚合说明书失败: {str(e)}")
                return {"error": f"查询聚合说明书失败: {str(e)}"}

        @self.mcp.tool()
        def query_process_documentation(
            session_id: str,
            repo_id: str,
            branch_name: str,
            document_ids: list[str],
            ctx: Context = None
        ) -> Dict[str, Any]:
            """
            查询流程说明书接口（支持批量查询）

            Args:
                session_id: 会话ID（必需，需要先调用start_session获取）
                repo_id: 仓库ID
                branch_name: 分支名称
                document_ids: 文档ID列表
                ctx: FastMCP上下文对象

            Returns:
                Dict: 包含说明书内容和关联方法的字典
            """
            try:
                logger.info(f"批量查询流程说明书: session_id={session_id}, repo_id={repo_id}, "
                           f"branch_name={branch_name}, document_ids={document_ids}")

                # 验证session是否存在
                if session_id not in sessions:
                    logger.error(f"无效的会话ID: {session_id}")
                    return {
                        "success": False,
                        "message": f"无效的会话ID: {session_id}，请先调用start_session创建会话",
                        "data": None
                    }

                # 获取session信息
                session = sessions[session_id]
                logger.info(f"使用会话: {session_id}, 创建时间: {session.get('created_at', 'Unknown')}")

                # 确保文档服务已初始化
                if not documentation_service.initialized:
                    success = documentation_service.initialize()
                    if not success:
                        error_msg = documentation_service.get_error_message()
                        logger.error(f"初始化文档服务失败: {error_msg}")
                        return {"error": f"无法查询文档: {error_msg}"}

                # 批量获取流程说明书内容
                doc_contents = documentation_service.get_process_documentation_batch(document_ids)

                if doc_contents:
                    logger.info(f"成功批量获取流程说明书内容: 获取到 {len(doc_contents)} 个文档")

                    # 构建批量结果
                    results = []
                    for doc_content in doc_contents:
                        results.append({
                            "id": doc_content.get('id', ''),
                            "content": doc_content.get('content', ''),
                            "entryPointId": doc_content.get('entryPointId', ''),
                            "entryPointName": doc_content.get('entryPointName', '')
                        })

                    return {
                        "success": True,
                        "message": "查询成功",
                        "session_id": session_id,
                        "total_results": len(results),
                        "data": results
                    }
                else:
                    logger.error(f"无法批量获取流程文档内容: document_ids={document_ids}")
                    return {
                        "success": False,
                        "message": "无法获取流程文档内容",
                        "session_id": session_id,
                        "document_ids": document_ids,
                        "data": None
                    }

            except Exception as e:
                logger.exception(f"查询流程说明书失败: {str(e)}")
                return {"error": f"查询流程说明书失败: {str(e)}"}

    def get_mcp_server(self):
        """获取FastMCP服务器实例"""
        return self.mcp

# 创建全局实例
mcp_service = MCPService()
mcp_server = mcp_service.get_mcp_server()

# 获取HTTP应用
def get_http_app(prefix="/mcp"):
    """获取FastMCP HTTP应用"""
    # 注意：path参数是指定在HTTP应用内部挂载的路径，而不是FastAPI挂载路径
    logger.info(f"创建FastMCP HTTP应用，内部路径: /")
    
    # 使用FastMCP的HTTP应用创建方法
    # 确保设置正确的内容类型和响应头
    http_app = mcp_server.http_app(transport="sse", path="/")
    
    # 移除中间件，可能导致兼容性问题
    # @http_app.middleware("http")
    # async def add_sse_headers(request, call_next):
    #     response = await call_next(request)
    #     # 为SSE响应设置正确的Content-Type
    #     if "event-stream" in response.headers.get("content-type", ""):
    #         response.headers["Content-Type"] = "text/event-stream"
    #         response.headers["Cache-Control"] = "no-cache"
    #         response.headers["Connection"] = "keep-alive"
    #         response.headers["X-Accel-Buffering"] = "no"  # 对Nginx有用
    #     return response
    
    return http_app

# 挂载FastMCP到FastAPI应用（新版本FastMCP 2.3.4推荐方法）
def mount_to_app(app, prefix="/mcp"):    
    """将MCP服务挂载到FastAPI应用"""
    logger.info(f"使用新版本挂载方法，挂载路径: {prefix}")
    # 创建FastMCP的HTTP应用，注意内部路径使用 / 而不是 /mcp    
    http_app = get_http_app()
    # 移除lifespan上下文传递，可能导致兼容性问题 
    # app.router.lifespan_context = http_app.router.lifespan_context        
    # 使用FastAPI的mount方法挂载到指定路径    
    logger.info(f"挂载MCP应用到FastAPI路径: {prefix}")
    app.mount(prefix, http_app)
    return app

# 直接运行时使用stdio接口
if __name__ == "__main__":
    from fastmcp.cli import run_stdio
    run_stdio(mcp_server) 